from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import f1_score, accuracy_score
import torch.utils.data as data
import numpy as np
import diffprivlib.models as dp

def _disjoint_datasets(data: np.array, labels: np.array, percentage: list, shuffle: bool = True):
    """
    Split a dataset into disjoint partitions whose sizes are proportional to ``percentage``.

    Partitions are carved off one at a time with ``train_test_split``: every entry of
    ``percentage`` except the last is taken from the samples that are still unassigned, and
    the last partition receives whatever is left over. Exact partition sizes are subject to
    the rounding ``train_test_split`` applies.

    Args:
        param data: data (note: inside this function the name shadows the module-level
            ``torch.utils.data as data`` import)
        param labels: labels corresponding to data
        param percentage: fractions of the *full* dataset for each partition, example
            [0.9, 0.02, 0.08] to get 90% train, 2% val and 8% test. The values are expected
            to add up to 1; this is not enforced, the last partition simply absorbs the rest.
        param shuffle: Shuffle dataset before split.
    Returns: tuple of two lists, one with data x and other with labels y. Both lists have
        len(percentage) entries, or a single entry holding the whole input when
        ``percentage`` is empty.
    """
    rest_x, rest_y = data, labels
    fractions = list(percentage)  # accept any iterable without touching the caller's object
    x = []
    y = []
    remaining = 1.0  # fraction of the full dataset that has not been assigned yet
    for fraction in fractions[:-1]:
        # ``fraction`` is relative to the full dataset, but the split below only sees the
        # unassigned samples, so rescale it to a share of what is left.
        share = fraction / remaining
        part_x, rest_x, part_y, rest_y = train_test_split(
            rest_x, rest_y, test_size=1 - share, shuffle=shuffle
        )
        x.append(part_x)
        y.append(part_y)
        remaining -= fraction
    # whatever was never carved off forms the final partition
    x.append(rest_x)
    y.append(rest_y)
    return x, y


def _fit_logreg(X: np.array, Y: np.array, params: dict):
    """
    Fit a single logistic-regression classifier on (X, Y) and return it.

    Args:
        param X: training features
        param Y: training labels
        param params: configuration dict; reads 'dp_model', 'penalty' and 'C', plus
            'epsilon' when 'dp_model' is truthy.
    Returns: the fitted sklearn ``LogisticRegression`` or, when params['dp_model'] is truthy,
        the fitted diffprivlib ``LogisticRegression`` step of the DP pipeline.
    """
    if params['dp_model']:
        # for the DP Logreg model, we use a MinMaxScaler() to bound sensitivity
        # (the maximal extent to which a data entry can change between neighboring datasets)
        pipeline = Pipeline([
            ('scaler', MinMaxScaler()),
            ('clf', dp.LogisticRegression(epsilon=params['epsilon'], penalty=params['penalty'],
                                          C=params['C'], fit_intercept=True, max_iter=2000)),
        ])
        pipeline.fit(X, Y)
        # NOTE(review): only the classifier step is kept, so the fitted scaler is dropped and
        # every later score()/predict() call sees unscaled inputs -- confirm this is intended.
        return pipeline['clf']

    model = LogisticRegression(penalty=params['penalty'], C=params['C'], fit_intercept=True, max_iter=2000)
    model.fit(X, Y)
    return model


def fit_linear_model(X_train: np.array, X_test: np.array, Y_train: np.array, Y_test: np.array, params: dict, loss: str = 'log'):
    """
    Train one logistic-regression model, or an ensemble of them, and print accuracies.

    Args:
        param X_train: training features
        param X_test: test features
        param Y_train: training labels
        param Y_test: test labels
        param params: configuration dict. Keys read here:
            'ensemble'      -- truthy to train several models instead of one
            'disjoint'      -- (ensemble only) truthy to train each model on its own partition
                               returned by ``_disjoint_datasets``
            'n_splits'      -- (disjoint only) number of partitions / models
            'n_ensemble'    -- (non-disjoint only) number of models
            'frac_ensemble' -- (non-disjoint only) ``test_size`` held out, and discarded, for
                               each ensemble member
            'dp_model'      -- truthy to train diffprivlib models (also needs 'epsilon')
            'penalty', 'C'  -- forwarded to the logistic-regression constructor
        param loss: currently unused; kept so existing callers keep working.
    Returns: the fitted classifier, or a list of fitted classifiers when params['ensemble']
        is truthy.
    Raises:
        ValueError: if ensemble training is requested but no model would be trained.
    """
    if not params['ensemble']:
        # if we do not do ensemble training
        model = _fit_logreg(X_train, Y_train, params)
        print('Training set accuracy:', np.round(model.score(X_train, Y_train), 4))
        print('Test set accuracy:', np.round(model.score(X_test, Y_test), 4))
        return model

    models = []
    X_sub = Y_sub = None  # training subset of the most recently fitted ensemble member
    if params['disjoint']:
        n_splits = params['n_splits']
        percentages = [1 / n_splits] * n_splits
        # split data into params['n_splits'] number of disjoint partitions
        X_parts, Y_parts = _disjoint_datasets(X_train, Y_train, percentage=percentages)
        for X_sub, Y_sub in zip(X_parts, Y_parts):
            models.append(_fit_logreg(X_sub, Y_sub, params))
    else:
        for i in range(params['n_ensemble']):
            # every member gets its own random subset (fixed seed per member, so runs are
            # repeatable); subsets of different members may overlap
            X_sub, _, Y_sub, _ = train_test_split(X_train, Y_train,
                                                  test_size=params['frac_ensemble'],
                                                  random_state=567 + i)
            models.append(_fit_logreg(X_sub, Y_sub, params))

    if not models:
        raise ValueError("ensemble training requested but no ensemble members were trained")

    # unweighted mean of the per-model test accuracies
    weight = 1 / len(models)
    score_test = sum(weight * member.score(X_test, Y_test) for member in models)

    # output that will appear in our experiment pipeline notebook
    print('Training set accuracy on last ensemble model:', np.round(models[-1].score(X_sub, Y_sub), 4))
    print('Test set accuracy across all models:', np.round(score_test, 4))
    return models
